# =============================================================================
# Reprojecting HAWC+ polarization cubes in galactic coordinates
# =============================================================================
from astropy.coordinates import SkyCoord
from astropy.io import fits
from astropy import wcs
from aplpy import FITSFigure
import math
import numpy as np
import os
from reproject import reproject_interp, reproject_exact

# =============================================================================
# Constants
# =============================================================================
# Galactic coordinates, in degrees, of the celestial north pole; read by
# rotation_angle() when rotating polarization angles from ICRS to Galactic.
# NOTE(review): the commonly quoted ICRS/J2000 pole latitude is b ~ 27.12825
# deg; the ~1e-4 deg offset is negligible here but worth confirming.
l_cnp, b_cnp = 122.93193212, 27.12835496

# =============================================================================
# Function to reproject from the International Celestial Reference System
# to the galactic frame of reference
# =============================================================================
def ICRStoGal(filename, SourceName, GalFolder=''):
    """Reproject a HAWC+ polarization file from ICRS to Galactic coordinates.

    A temporary FITS file holding Galactic target headers is created first.
    Every image extension is reprojected onto that grid. The polarization
    products are recomputed from the reprojected Stokes maps, and the
    polarization angles are rotated into the Galactic frame.

    Parameters
    ----------
    filename : str
        Path to the input HAWC+ FITS file (16 image HDUs followed by two
        table HDUs are expected).
    SourceName : str
        Prefix of the output file name.
    GalFolder : str, optional
        Prepended verbatim to the output file name, so a directory must end
        with a path separator.

    Returns
    -------
    str
        Path of the written ``<GalFolder><SourceName>_Gal.fits`` file.
    """
    # Create the temporary FITS file with the modified (Galactic) headers
    # that define the reprojection grid
    tempfilename = temporary_FITS(filename)
    try:
        # Context managers close both files even if a later step raises
        with fits.open(filename) as hawcfile, \
                fits.open(tempfilename) as tempfits:
            # Reproject every image extension onto the Galactic grid
            newfile = reprojection(hawcfile, tempfits)
            # HDUList.info() prints the summary itself and returns None
            newfile.info()

            # Recalculate the polarization properties of the reprojected
            # maps, then rotate the angles into the Galactic frame
            workfile = recalculate_polarization(newfile)
            finalfile = convert_polarization(workfile)

            # Write while the input is still open: the table HDUs of the
            # output are the input file's own HDU objects
            newfilename = GalFolder + SourceName + '_Gal.fits'
            finalfile.writeto(newfilename, overwrite=True)
    finally:
        # Always delete the temporary file, also when something failed
        os.remove(tempfilename)

    return newfilename

# =============================================================================
# Function to recalculate the polarization for the data using the 
# International Celestial Reference System
# =============================================================================
def ICRStoICRS(filename, SourceName, ICRSFolder=''):
    """Recompute the polarization products of a HAWC+ file, staying in ICRS.

    No reprojection is performed and no header is modified, so no
    temporary FITS file is needed.

    Parameters
    ----------
    filename : str
        Path to the input HAWC+ FITS file.
    SourceName : str
        Prefix of the output file name.
    ICRSFolder : str, optional
        Prepended verbatim to the output file name, so a directory must end
        with a path separator.

    Returns
    -------
    str
        Path of the written ``<ICRSFolder><SourceName>_RA.fits`` file.
    """
    # The context manager closes the input even if a step below raises
    with fits.open(filename) as hawcfile:
        # Recalculate the polarization properties (no frame conversion)
        workfile = recalculate_polarization(hawcfile)

        # Write while the input is still open: the table HDUs of the output
        # are the input file's own HDU objects
        newfilename = ICRSFolder + SourceName + '_RA.fits'
        workfile.writeto(newfilename, overwrite=True)
        # HDUList.info() prints the summary itself and returns None
        workfile.info()

    return newfilename

# =============================================================================
# Function to hack the headers of all the extensions of a FITS file
# =============================================================================
def convert_headers(headersfile):
    """Run hack_header() in place on the headers of HDUs 0 through 15.

    HAWC+ files are expected to hold 16 image extensions; indexing fails if
    ``headersfile`` has fewer HDUs. Returns None.
    """
    n_image_hdus = 16
    for position in range(n_image_hdus):
        hack_header(headersfile[position].header)

# =============================================================================
# Function to rotate the polarization angles and re-calculate Stokes Q & U
# =============================================================================
def convert_polarization(workfile):
    """Rotate polarization angles into the Galactic frame and rebuild Q and U.

    ``workfile`` is an HDUList laid out like the output of
    recalculate_polarization(). It is modified in place and also returned.

    * HDU 10 (polarization angle O) gets the local ICRS-to-Galactic rotation
      from rotation_angle() added. The rotation is evaluated at each pixel's
      (l, b) from the WCS of HDU 10, which must be a Galactic (GLON/GLAT)
      frame.
    * HDU 11 (angle B) is set to O + 90 deg.
    * Both angles are wrapped into (-90, 90] deg, since they are defined
      modulo 180 deg. NaN pixels stay NaN.
    * HDUs 2 and 4 (Stokes Q and U) are rebuilt from HDU 13 (polarized
      intensity) and the rotated angle.
    * The Q/U uncertainties (HDUs 3 and 5) are left unchanged.
    """
    angle_icrs = workfile[10].data
    WCSinfo = wcs.WCS(workfile[10].header)

    # Galactic coordinates of every pixel in one vectorized call. numpy
    # indexes the 2-D map as [row, column] = [y, x], while from_pixel takes
    # (x, y), so the column grid must be passed first.
    rows, cols = np.indices(np.shape(angle_icrs))
    PixCoord = SkyCoord.from_pixel(cols, rows, WCSinfo)
    # rotation_angle() is built from numpy ufuncs, so it works element-wise
    PolO_gal = angle_icrs + rotation_angle(PixCoord.l.deg, PixCoord.b.deg)

    # Angles from recalculate_polarization() lie in [-90, 90] and the
    # rotation in [-180, 180], so the sum can leave (-90, 90] on either
    # side. Wrap both ends modulo 180 deg.
    with np.errstate(invalid='ignore'):
        PolO_gal = 90.0 - np.mod(90.0 - PolO_gal, 180.0)
        # Polarization angle B, perpendicular to O
        PolB_gal = 90.0 - np.mod(90.0 - (PolO_gal + 90.0), 180.0)

    # Rebuild Stokes Q and U from the polarized intensity and rotated angle
    PolIp = workfile[13].data
    PolQ_gal = PolIp * np.cos(2.0*np.radians(PolO_gal))
    PolU_gal = PolIp * np.sin(2.0*np.radians(PolO_gal))
    # The Stokes Q and U uncertainties are kept as they are: given how the
    # data are reduced, they are assumed to be close in value regardless of
    # the reference frame

    # Replace the data in the FITS container
    workfile[2].data = PolQ_gal
    workfile[4].data = PolU_gal
    workfile[10].data = PolO_gal
    workfile[11].data = PolB_gal
    # Returned for convenience; the container was already updated in place
    return workfile

# =============================================================================
# Function to hack a header from ICRS to Gal
# =============================================================================
def hack_header(header):
    """Rewrite an ICRS image header in place so it describes a Galactic grid.

    * Each axis is padded by 100 pixels. The container becomes a square
      whose side is the padded diagonal, so the map still fits once
      rotated into Galactic orientation.
    * CTYPE1/2 become GLON-TAN/GLAT-TAN and WCSNAME is updated.
    * CRPIX1/2 move to the middle of the new grid.
    * CRVAL1/2 are converted from ICRS (RA, Dec) to Galactic (l, b), in
      degrees.
    * The EQUINOX comment (if that keyword exists) is relabelled and a
      COMMENT card records the conversion.

    Pixel-scale keywords are not modified. Returns None.
    """
    # Padded axis lengths
    naxis1 = header['NAXIS1'] + 100
    naxis2 = header['NAXIS2'] + 100
    # Side of the square container: the diagonal (hypotenuse) of the
    # padded image
    diagonal = math.hypot(naxis1, naxis2)
    # NAXISn must be integers in a FITS header; truncate to whole pixels
    side = int(diagonal)
    header['NAXIS1'] = side
    header['NAXIS2'] = side

    # Change the coordinate system keywords
    header['CTYPE1'] = 'GLON-TAN'
    header['CTYPE2'] = 'GLAT-TAN'

    # New reference pixel in the middle of the container
    header['CRPIX1'] = math.ceil(diagonal/2)
    header['CRPIX2'] = math.ceil(diagonal/2)

    # Convert the reference sky position from ICRS to Galactic
    OldCoord = SkyCoord(ra=header['CRVAL1'], dec=header['CRVAL2'],
                        frame='icrs', unit='deg')
    NewCoord = OldCoord.galactic
    header['CRVAL1'] = NewCoord.l.deg
    header['CRVAL2'] = NewCoord.b.deg

    # Replace the coordinate system name
    header['WCSNAME'] = 'Galactic Coordinates'

    # Setting the comment of a missing keyword raises KeyError, so only
    # relabel EQUINOX when the header actually has it
    if 'EQUINOX' in header:
        header.comments['EQUINOX'] = 'Equinox of observations'

    # Record the conversion
    header['COMMENT'] = 'Data was converted from Celestial to Galactic coordinates'

# =============================================================================
# Function to calculate the polarization properties from Stokes I, Q, and U
# from a fits container opened from an HAWC+ fits file
# See HAWC+ white paper for structure (Gordon et al. 2018)
# https://ui.adsabs.harvard.edu/abs/2018arXiv181103100G/abstract
# !! The calculation below assumes zero covariance as information is missing !!
# =============================================================================
def recalculate_polarization(newfile):
    """Recompute the HAWC+ polarization products from Stokes I, Q and U.

    Zero Q-U covariance is assumed, since that information is missing.

    Parameters
    ----------
    newfile : HDUList
        HAWC+-style container. HDUs 0-5 hold Stokes I, dI, Q, dQ, U, dU;
        HDU 6 holds the image mask; HDUs 7-15 provide the headers reused for
        the derived products; HDUs 16 and 17 are tables.

    Returns
    -------
    HDUList
        A new container with the same layout:

        * HDUs 0-6 carry the input arrays.
        * HDUs 7-9: biased P, debiased P, dP (percent; P clipped to
          [0, 100], dP capped at 100).
        * HDUs 10-12: angle O in [-90, 90] deg, angle B = O + 90 deg wrapped
          into (-90, 90], and dO.
        * HDUs 13-15: biased PI, dPI, debiased PI.
        * HDUs 16-17 are the input's own table HDU objects.
    """
    # Human-readable names for the Stokes maps and their errors
    StokesI, dI = newfile[0].data, newfile[1].data
    StokesQ, dQ = newfile[2].data, newfile[3].data
    StokesU, dU = newfile[4].data, newfile[5].data

    # Biased polarized intensity and its uncertainty
    PI_biased = (StokesQ**2.0 + StokesU**2.0)**0.5
    dPI = ((StokesQ*dQ)**2.0 + (StokesU*dU)**2.0)**0.5/PI_biased
    # Where PI is smaller than its uncertainty, the debiased PI is set to 0.
    # Zeroing dPI there first keeps the square root free of negative values.
    pimask = PI_biased < dPI
    dPI_debias = np.copy(dPI)
    dPI_debias[pimask] = 0.0
    PolPI = (PI_biased**2.0 - dPI_debias**2.0)**0.5
    PolPI[pimask] = 0.0

    # Polarization fraction P in percent
    P_biased = 100.0*PI_biased/StokesI
    PolP = 100.0*PolPI/StokesI
    # Uncertainty of P, computed from the unclipped biased fraction
    dP = np.absolute(P_biased*((dPI/PI_biased)**2.0 + (dI/StokesI)**2.0)**0.5)
    # Clamp unphysical results to 0 or 100
    P_biased[P_biased < 0.0] = 0.0
    P_biased[P_biased > 100.0] = 100.0
    PolP[PolP < 0.0] = 0.0
    PolP[PolP > 100.0] = 100.0
    dP[dP > 100.0] = 100.0

    # Polarization angle O in degrees; arctan2 bounds it to [-90, 90]
    PolO = (0.5*180.0/math.pi)*np.arctan2(StokesU, StokesQ)
    # Angle uncertainty dO
    dO = (0.5*180.0/math.pi)*((StokesQ*dU)**2.0 +
                              (StokesU*dQ)**2.0)**0.5/PI_biased**2.0
    # Polarization angle B = O + 90 deg, wrapped back into (-90, 90]
    PolB = PolO + 90.0
    PolB[PolB > 90.0] -= 180.0

    # Data for HDUs 0-15, in order. The ndarrays are passed themselves:
    # ``ndarray.data`` would be a raw memoryview buffer, not an array.
    arrays = [
        StokesI, dI,            # 0-1: Stokes I and its error
        StokesQ, dQ,            # 2-3: Stokes Q and its error
        StokesU, dU,            # 4-5: Stokes U and its error
        newfile[6].data,        # 6: image mask
        P_biased, PolP, dP,     # 7-9: polarization fraction
        PolO, PolB, dO,         # 10-12: polarization angle
        PI_biased, dPI, PolPI,  # 13-15: polarized intensity
    ]
    # Each array keeps the header of the same-index HDU from the input
    workfile = fits.HDUList()
    for index, data in enumerate(arrays):
        workfile.append(fits.ImageHDU(data=data, header=newfile[index].header))
    # Carry the table HDUs over unchanged
    workfile.append(newfile[16])
    workfile.append(newfile[17])

    return workfile

# =============================================================================
# Function to create a fits container with galactic coordinate headers 
# =============================================================================
def reprojection(hawcfile, tempfile):
    """Return a new HDUList with the 16 image HDUs reprojected onto new grids.

    Each image HDU ``hawcfile[k]`` (k = 0..15) is reprojected with
    ``reproject_exact`` onto the grid described by ``tempfile[k].header``,
    and stored with that header. The footprint arrays are discarded.
    HDUs 16 and 17 of ``hawcfile`` (tables) are appended unchanged.
    """
    reprojected = fits.HDUList()
    for index in range(16):
        target_header = tempfile[index].header
        array, _footprint = reproject_exact(hawcfile[index], target_header)
        reprojected.append(fits.ImageHDU(data=array, header=target_header))

    # The tables are copied across as-is; they may need their own
    # coordinate conversion
    for index in (16, 17):
        reprojected.append(hawcfile[index])

    return reprojected

# =============================================================================
# Rotation of polarization angles from ICRS to Gal
# from Appenzeller (1968) - see Stephens+ (2022)
# =============================================================================
def rotation_angle(l,b):
    """Return the ICRS-to-Galactic polarization-angle rotation, in degrees.

    Follows Appenzeller (1968), as used by Stephens et al. (2022), with the
    pole position taken from the module constants ``l_cnp`` and ``b_cnp``.
    ``l`` and ``b`` are Galactic longitude and latitude in degrees. They may
    be scalars or numpy arrays (all operations are element-wise). The result
    comes from arctan2 and so lies in [-180, 180].
    """
    delta_l = np.radians(l_cnp) - np.radians(l)
    lat = np.radians(b)
    num = np.sin(delta_l)
    den = np.tan(np.radians(b_cnp))*np.cos(lat) - np.sin(lat)*np.cos(delta_l)
    return np.degrees(np.arctan2(num, den))

# =============================================================================
# Function to create a temporary FITS file for the reprojection
# This step is necessary to avoid unusual errors that would occur when
# using reproject with a copy() of the original FITS file instead
# =============================================================================
def temporary_FITS(filename):
    """Write a placeholder FITS file carrying the converted (Galactic) headers.

    The headers of the first 16 HDUs of ``filename`` are rewritten by
    convert_headers(). They are then stored next to zero-filled arrays of
    the new size in a freshly written file. Reprojecting against this file
    avoids the errors seen when using a copy() of the original file instead.

    Returns
    -------
    str
        Path of a uniquely named file in the system temporary directory.
        The caller is responsible for deleting it.
    """
    import tempfile

    with fits.open(filename) as headersfile:
        # Hack the headers of the file opened in memory
        convert_headers(headersfile)

        # Reference array; numpy shapes are ordered (NAXIS2, NAXIS1)
        naxis1 = int(headersfile[0].header['NAXIS1'])
        naxis2 = int(headersfile[0].header['NAXIS2'])
        narray = np.zeros((naxis2, naxis1))

        # Multi-extension file: one placeholder HDU per hacked header
        newfile = fits.HDUList()
        for i in range(16):
            newfile.append(fits.ImageHDU(data=narray,
                                         header=headersfile[i].header))

        # A unique file avoids clobbering (and later deleting) a user's own
        # 'temp.fits' in the working directory, and avoids collisions
        # between concurrent runs
        handle, tempfilename = tempfile.mkstemp(suffix='.fits')
        os.close(handle)
        try:
            newfile.writeto(tempfilename, overwrite=True)
        except BaseException:
            # Do not leave an empty or partial temporary file behind
            os.remove(tempfilename)
            raise

    return tempfilename

# =============================================================================
# Magnetic field map on Stokes I total intensity
# =============================================================================
def BvImap(Iref, Pref, IdI=10.0, PdP=3.0, Pmax=30.0,
           StepVec=3, ScaleVec=2, Colorbar='top', Linewidth=2.0):
    """Draw magnetic-field (B) vectors over a Stokes I map with APLpy.

    Parameters
    ----------
    Iref
        Background image, passed directly to ``FITSFigure``.
    Pref : HDUList
        Container providing I and dI (HDUs 0 and 1), P and dP (HDUs 8 and 9)
        and the B angle (HDU 11). The header of HDU 0 is reused for the
        vector maps.
    IdI : float
        Minimum I/dI for a vector to be drawn.
    PdP : float
        Minimum P/dP for a vector to be drawn.
    Pmax : float
        Vectors with P above this value are hidden.
    StepVec, ScaleVec, Linewidth
        Forwarded to ``show_vectors`` as ``step``, ``scale`` and ``linewidth``.
    Colorbar : str
        Colorbar location.

    Returns
    -------
    FITSFigure
        The figure with the intensity map and the vectors.
    """
    # Intensity background, colorbar and HAWC+ beam
    figure = FITSFigure(Iref)
    figure.show_colorscale(cmap='Greys')
    figure.add_colorbar(pad=0.25)
    figure.colorbar.set_location(Colorbar)
    figure.colorbar.set_axis_label_text('Stokes $I$ Intensity (mJy arcsec$^{-2}$)')
    figure.add_beam(facecolor='white', edgecolor='blue',
                    linewidth=2, pad=1, corner='bottom left')

    intensity, d_intensity = Pref[0].data, Pref[1].data
    frac, d_frac = Pref[8].data, Pref[9].data
    base_header = Pref[0].header

    # HDU objects for show_vectors. copy() gives the length map its own data
    # array, so editing it below leaves Pref untouched.
    pmap = fits.PrimaryHDU(data=frac, header=base_header).copy()
    omap = fits.PrimaryHDU(data=Pref[11].data, header=base_header).copy()

    # Pixels failing a selection criterion, computed from the inputs
    rejected = (
        np.where(intensity/d_intensity < IdI),  # total-intensity SNR
        np.where(intensity < 0),                # negative total intensity
        np.where(frac/d_frac < PdP),            # polarization-fraction SNR
        np.where(frac > Pmax),                  # polarization fraction above Pmax
    )

    # Equal-length vectors wherever P is positive, then blank the rejects
    pmap.data[np.where(frac > 0.0)] = 1.0
    for selection in rejected:
        pmap.data[selection] = np.nan

    figure.show_vectors(pmap, omap, scale=ScaleVec, step=StepVec, color='red',
                        zorder=3, alpha=0.8, linewidth=Linewidth)

    return figure